[PyTorch] make GroupedLinear inp support collection of torch.Tensor #1120
Conversation
Yeap, I noticed that. My plan is to implement a fused kernel to deal with multi… By the way, have you considered implementing a fused kernel to eliminate the call of… Thanks a lot!

Not yet. It's a good idea; maybe I can give it a try. I'm working on enabling FP8 support in MCore's …

Hi @BeingGod, do you still need this feature with the padding method I pasted above?

Hi @yaox12, it seems our work has some conflict. I'm trying to fuse … Can you help me review this solution?

@BeingGod Looks reasonable to me.
Description
For the FP8 GroupedMLP `linear_fc1`, to make sure each tensor's shape is aligned to 16, we split the activation, pad each tensor, and then concatenate the list of tensors into a single tensor that is passed as the `inp` argument of GroupedLinear. Inside `_GroupedLinear`, that tensor is split again based on `m_splits`, so the cat and split are redundant at this point. We would like `_GroupedLinear` to accept `inp` as a collection of tensors (e.g. `List[torch.Tensor]` or `Tuple[torch.Tensor]`) to save 2 cat kernel calls (1 in forward + 1 in backward).

Profiling: [kernel timeline screenshots in the original PR]
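As a rough illustration (not the code from this PR), the sketch below mimics the pad-and-cat pattern described above and the proposed collection-valued `inp`. The tensor shapes, the `pad_to_multiple` helper, and the commented `grouped_linear(inp, m_splits)` call signature are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

ALIGN = 16  # FP8 GEMMs need the token (first) dimension padded to a multiple of 16

def pad_to_multiple(t: torch.Tensor, multiple: int = ALIGN) -> torch.Tensor:
    """Zero-pad the first dimension of `t` up to the next multiple of `multiple`."""
    pad_rows = (-t.shape[0]) % multiple
    return F.pad(t, (0, 0, 0, pad_rows)) if pad_rows else t

# Hypothetical per-expert activations for a GroupedMLP linear_fc1.
hidden = 128
tokens_per_expert = [30, 45, 17]
acts = [torch.randn(n, hidden) for n in tokens_per_expert]

# Current pattern: pad each expert's tokens, cat into one tensor, and let
# _GroupedLinear split it again by m_splits (one extra cat in forward, plus
# a matching cat of gradients in backward).
padded = [pad_to_multiple(t) for t in acts]
m_splits = [t.shape[0] for t in padded]
inp = torch.cat(padded, dim=0)
# out = grouped_linear(inp, m_splits)     # splits `inp` internally by m_splits

# Proposed pattern (this PR): pass the collection directly, so neither the
# caller's cat nor the module's split is needed.
# out = grouped_linear(padded, m_splits)
```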
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- Make GroupedLinear `inp` support a collection of torch.Tensor

Checklist: